

import os
from collections import OrderedDict

import torch
from torch.nn.modules import Sequential

from .backbone import build_backbone
from .heads import build_linear_head,build_ra_head2, build_ra_head3
from cbml_benchmark.utils.freeze_bn import freeze

def build_model(cfg):
    """Assemble the full model and optionally load pretrained weights.

    The model is a ``Sequential`` whose sub-modules are registered, in
    order, as ``backbone``, ``localhead2``, ``localhead3`` and ``finalhead``.

    Weight loading is selected by ``cfg.MODEL.PRETRAIN``:

    * ``'imagenet'`` -- the path stored under
      ``cfg.MODEL.PRETRIANED_PATH[cfg.MODEL.BACKBONE.NAME]`` is
      user-expanded and handed to ``model.backbone.load_param``.
    * an existing filesystem path -- if the backbone name does not contain
      ``'mit'`` (or does contain ``'output'``), the file is read with
      ``torch.load`` and its ``'model'`` entry is loaded into the whole
      model with ``strict=False``; otherwise the path is handed to
      ``model.backbone.init_weights``.
    * anything else -- no weights are loaded and a notice is printed.

    Args:
        cfg: Project configuration node. This function reads
            ``cfg.MODEL.PRETRAIN``, ``cfg.MODEL.BACKBONE.NAME`` and, for
            the ``'imagenet'`` mode, ``cfg.MODEL.PRETRIANED_PATH``; the
            same ``cfg`` is forwarded to the backbone/head builders.

    Returns:
        The assembled ``torch.nn.Sequential`` model.

    Raises:
        KeyError: If a checkpoint file is loaded but has no ``'model'``
            entry, or the backbone name is missing from
            ``cfg.MODEL.PRETRIANED_PATH`` in ``'imagenet'`` mode.
    """
    backbone = build_backbone(cfg)

    # Kept for reference (disabled in the original):
    # if 'mit' in cfg.MODEL.BACKBONE.NAME or 'deit' in cfg.MODEL.BACKBONE.NAME:
    #     freeze(backbone, 0)

    # Build order is kept as in the original so that any RNG-dependent
    # parameter initialisation inside the builders is unchanged.
    head = build_linear_head(cfg)
    localhead2 = build_ra_head2(cfg)
    localhead3 = build_ra_head3(cfg)

    model = Sequential(OrderedDict([
        ('backbone', backbone),
        ('localhead2', localhead2),
        ('localhead3', localhead3),
        ('finalhead', head),
    ]))

    pretrain = cfg.MODEL.PRETRAIN
    backbone_name = cfg.MODEL.BACKBONE.NAME

    if pretrain == 'imagenet':
        print('Loading imagenet pretrianed model ...')
        pretrained_path = os.path.expanduser(
            cfg.MODEL.PRETRIANED_PATH[backbone_name])
        model.backbone.load_param(pretrained_path)
    elif os.path.exists(pretrain):
        # NOTE(review): the 'output' test is applied to the backbone NAME,
        # exactly as in the original; it may have been intended for the
        # checkpoint path -- confirm with the config conventions.
        if 'mit' not in backbone_name or 'output' in backbone_name:
            # map_location='cpu' lets a checkpoint saved from GPU tensors be
            # read on any host; load_state_dict copies the values into the
            # model's existing parameters, so the result is unchanged.
            ckp = torch.load(pretrain, map_location='cpu')
            model.load_state_dict(ckp['model'], strict=False)
        else:
            # 'mit' backbone without 'output' in its name: the backbone
            # consumes the weight file itself.
            model.backbone.init_weights(pretrain)
    else:
        # Previously this case fell through silently, which hides a
        # mistyped checkpoint path.
        print('No pretrained weights loaded: cfg.MODEL.PRETRAIN={!r} is '
              "neither 'imagenet' nor an existing path.".format(pretrain))
    return model
